Commit 319d9d8b authored by yuhai's avatar yuhai
Browse files

Initial commit

parents
/*
Copyright 2016 Medcl (m AT medcl.net)
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
import (
"net/http"
"github.com/parnurzeal/gorequest"
log "github.com/cihub/seelog"
"io/ioutil"
"io"
"errors"
"bytes"
"net/url"
)
// Get issues an HTTP GET to url with optional basic auth and proxy,
// returning the raw response, the body as a string, and any errors.
func Get(url string, auth *Auth, proxy string) (*http.Response, string, []error) {
	req := gorequest.New()
	if auth != nil {
		req.SetBasicAuth(auth.User, auth.Pass)
	}
	req.Header["Content-Type"] = "application/json"
	if proxy != "" {
		req.Proxy(proxy)
	}
	return req.Get(url).End()
}
// Post issues an HTTP POST to url with an optional JSON body, basic auth
// and proxy, returning the raw response, body string and any errors.
func Post(url string, auth *Auth, body string, proxy string) (*http.Response, string, []error) {
	req := gorequest.New()
	if auth != nil {
		req.SetBasicAuth(auth.User, auth.Pass)
	}
	req.Header["Content-Type"] = "application/json"
	if proxy != "" {
		req.Proxy(proxy)
	}
	req.Post(url)
	if body != "" {
		req.Send(body)
	}
	return req.End()
}
// newDeleteRequest builds a bare *http.Request for urlStr without a body.
// An empty method defaults to "GET", mirroring net/http.NewRequest.
// The client argument is unused but kept for call-site compatibility.
func newDeleteRequest(client *http.Client, method, urlStr string) (*http.Request, error) {
	if method == "" {
		// "" means "GET" for Request.Method, as in net/http.NewRequest.
		method = "GET"
	}
	parsed, err := url.Parse(urlStr)
	if err != nil {
		return nil, err
	}
	request := &http.Request{
		Method:     method,
		URL:        parsed,
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,
		Header:     make(http.Header),
		Host:       parsed.Host,
	}
	return request, nil
}
// Request performs an HTTP request with optional basic auth, JSON content
// type and proxy, returning the response body as a string.
// A transport failure or a non-200 status is reported as an error; the
// non-200 error message carries the response body.
func Request(method string, r string, auth *Auth, body *bytes.Buffer, proxy string) (string, error) {
	client := &http.Client{}
	if len(proxy) > 0 {
		proxyURL, err := url.Parse(proxy)
		if err != nil {
			// keep going without the proxy, matching previous behavior
			log.Error(err)
		} else {
			client = &http.Client{Transport: &http.Transport{Proxy: http.ProxyURL(proxyURL)}}
		}
	}
	var reqest *http.Request
	var err error
	if body != nil {
		reqest, err = http.NewRequest(method, r, body)
	} else {
		// a typed-nil *bytes.Buffer must not reach NewRequest (nil-interface
		// trap), so build a body-less request directly
		reqest, err = newDeleteRequest(client, method, r)
	}
	// bug fix: the construction error was previously discarded with "_",
	// which could nil-deref below on a malformed URL
	if err != nil {
		log.Error(err)
		return "", err
	}
	if auth != nil {
		reqest.SetBasicAuth(auth.User, auth.Pass)
	}
	reqest.Header.Set("Content-Type", "application/json")
	resp, err := client.Do(reqest)
	if err != nil {
		log.Error(err)
		return "", err
	}
	// bug fix: close on every path; the body previously leaked when the
	// status code was not 200
	defer resp.Body.Close()
	respBody, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		log.Error(err)
		return string(respBody), err
	}
	if resp.StatusCode != 200 {
		return "", errors.New("server error: " + string(respBody))
	}
	log.Trace(r, string(respBody))
	return string(respBody), nil
}
/*
Copyright 2016 Medcl (m AT medcl.net)
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
import (
"strings"
log "github.com/cihub/seelog"
)
// setInitLogging configures seelog with a console appender plus an
// error-level file appender, using the given minimum level
// (case-insensitive).
func setInitLogging(logLevel string) {
	level := strings.ToLower(logLevel)
	config := `
<seelog type="sync" minlevel="` + level + `">
<outputs formatid="main">
<filter levels="error">
<file path="./log/gopa.log"/>
</filter>
<console formatid="main" />
</outputs>
<formats>
<format id="main" format="[%Date(01-02) %Time] [%LEV] [%File:%Line,%FuncShort] %Msg%n"/>
</formats>
</seelog>`
	logger, err := log.LoggerFromConfigAsString(config)
	if err != nil {
		log.Error("init config error,", err)
	}
	if err = log.ReplaceLogger(logger); err != nil {
		log.Error("init config error,", err)
	}
}
package main
import (
"encoding/json"
"fmt"
"runtime"
"strings"
"sync"
"time"
"bufio"
log "github.com/cihub/seelog"
goflags "github.com/jessevdk/go-flags"
pb "gopkg.in/cheggaaa/pb.v1"
"os"
"io"
)
// main drives the whole migration: parse flags, open the input (a source ES
// scroll or a dump file), open the output (a target ES bulk pipeline or a
// dump file), optionally migrate settings/mappings, then pump documents
// through migrator.DocChan until all workers finish.
func main() {
runtime.GOMAXPROCS(runtime.NumCPU())
c := &Config{}
migrator:=Migrator{}
migrator.Config=c
// parse args
_, err := goflags.Parse(c)
if err != nil {
log.Error(err)
return
}
setInitLogging(c.LogLevel)
if len(c.SourceEs) == 0 && len(c.DumpInputFile) == 0 {
log.Error("no input, type --help for more details")
return
}
if len(c.TargetEs) == 0 && len(c.DumpOutFile) == 0 {
log.Error("no output, type --help for more details")
return
}
// refuse to copy an index onto itself
// NOTE(review): the message likely means "the input is the same as the output"
if c.SourceEs == c.TargetEs && c.SourceIndexNames == c.TargetIndexName {
log.Error("migration output is the same as the output")
return
}
// enough of a buffer to hold all the search results across all workers
migrator.DocChan = make(chan map[string]interface{}, c.DocBufferCount*c.Workers*10)
var srcESVersion *ClusterVersion
// create a progressbar and start a docCount
var outputBar *pb.ProgressBar
var fetchBar = pb.New(1).Prefix("Scroll")
wg := sync.WaitGroup{}
//dealing with input
if len(c.SourceEs) > 0 {
//dealing with basic auth
if len(c.SourceEsAuthStr) > 0 && strings.Contains(c.SourceEsAuthStr, ":") {
authArray := strings.Split(c.SourceEsAuthStr, ":")
auth := Auth{User: authArray[0], Pass: authArray[1]}
migrator.SourceAuth = &auth
}
//get source es version
// NOTE(review): ":=" shadows the outer srcESVersion, which therefore stays
// nil, so the cross-big-version mapping check further down can never fire.
srcESVersion, errs := migrator.ClusterVersion(c.SourceEs, migrator.SourceAuth,migrator.Config.SourceProxy)
if errs != nil {
return
}
if strings.HasPrefix(srcESVersion.Version.Number, "5.") {
log.Debug("source es is V5,", srcESVersion.Version.Number)
api := new(ESAPIV5)
api.Host = c.SourceEs
api.Auth = migrator.SourceAuth
api.HttpProxy=migrator.Config.SourceProxy
migrator.SourceESAPI = api
} else {
log.Debug("source es is not V5,", srcESVersion.Version.Number)
api := new(ESAPIV0)
api.Host = c.SourceEs
api.Auth = migrator.SourceAuth
api.HttpProxy=migrator.Config.SourceProxy
migrator.SourceESAPI = api
}
// a sliced scroll needs at least one slice
if(c.ScrollSliceSize<1){c.ScrollSliceSize=1}
fetchBar.ShowBar=false
totalSize:=0;
finishedSlice:=0
for slice:=0;slice<c.ScrollSliceSize ;slice++ {
scroll, err := migrator.SourceESAPI.NewScroll(c.SourceIndexNames, c.ScrollTime, c.DocBufferCount, c.Query,slice,c.ScrollSliceSize, c.Fields)
if err != nil {
log.Error(err)
return
}
// NOTE(review): scroll is dereferenced here before the nil check below
totalSize+=scroll.Hits.Total
if scroll != nil && scroll.Hits.Docs != nil {
if scroll.Hits.Total == 0 {
log.Error("can't find documents from source.")
return
}
go func() {
// NOTE(review): wg.Add should run before starting the goroutine; calling
// it inside races with wg.Wait. finishedSlice++ below is also a data race
// when several slices run concurrently.
wg.Add(1)
//process input
// start scroll
scroll.ProcessScrollResult(&migrator, fetchBar)
// loop scrolling until done
for scroll.Next(&migrator, fetchBar) == false {
}
fetchBar.Finish()
// finished, close doc chan and wait for goroutines to be done
wg.Done()
finishedSlice++
//clean up final results
if(finishedSlice==c.ScrollSliceSize){
log.Debug("closing doc chan")
close(migrator.DocChan)
}
}()
}
}
if(totalSize>0){
fetchBar.Total=int64(totalSize)
fetchBar.ShowBar=true
outputBar = pb.New(totalSize).Prefix("Output ")
}
} else if len(c.DumpInputFile) > 0 {
//read file stream
wg.Add(1)
f, err := os.Open(c.DumpInputFile)
if err != nil {
log.Error(err)
return
}
//get file lines
lineCount := 0
defer f.Close()
r := bufio.NewReader(f)
for{
_,err := r.ReadString('\n')
if io.EOF == err || nil != err{
break
}
// NOTE(review): a final line without a trailing '\n' is not counted
lineCount += 1
}
log.Trace("file line,", lineCount)
fetchBar := pb.New(lineCount).Prefix("Read")
outputBar = pb.New(lineCount).Prefix("Output ")
f.Close()
go migrator.NewFileReadWorker(fetchBar,&wg)
}
// start pool
pool, err := pb.StartPool(fetchBar, outputBar)
if err != nil {
panic(err)
}
//dealing with output
if len(c.TargetEs) > 0 {
if len(c.TargetEsAuthStr) > 0 && strings.Contains(c.TargetEsAuthStr, ":") {
authArray := strings.Split(c.TargetEsAuthStr, ":")
auth := Auth{User: authArray[0], Pass: authArray[1]}
migrator.TargetAuth = &auth
}
//get target es version
descESVersion, errs := migrator.ClusterVersion(c.TargetEs, migrator.TargetAuth,migrator.Config.TargetProxy)
if errs != nil {
return
}
if strings.HasPrefix(descESVersion.Version.Number, "5.") {
log.Debug("target es is V5,", descESVersion.Version.Number)
api := new(ESAPIV5)
api.Host = c.TargetEs
api.Auth = migrator.TargetAuth
api.HttpProxy=migrator.Config.TargetProxy
migrator.TargetESAPI = api
} else {
log.Debug("target es is not V5,", descESVersion.Version.Number)
api := new(ESAPIV0)
api.Host = c.TargetEs
api.Auth = migrator.TargetAuth
api.HttpProxy=migrator.Config.TargetProxy
migrator.TargetESAPI = api
}
log.Debug("start process with mappings")
if srcESVersion != nil && c.CopyIndexMappings && descESVersion.Version.Number[0] != srcESVersion.Version.Number[0] {
log.Error(srcESVersion.Version, "=>", descESVersion.Version, ",cross-big-version mapping migration not avaiable, please update mapping manually :(")
return
}
// wait for cluster state to be okay before moving
// NOTE(review): this timer fires only once; if a cluster stays unhealthy
// past the first 3s tick, a second <-timer.C blocks forever — confirm.
timer := time.NewTimer(time.Second * 3)
for {
if len(c.SourceEs) > 0 {
if status, ready := migrator.ClusterReady(migrator.SourceESAPI); !ready {
log.Infof("%s at %s is %s, delaying migration ", status.Name, c.SourceEs, status.Status)
<-timer.C
continue
}
}
if len(c.TargetEs) > 0 {
if status, ready := migrator.ClusterReady(migrator.TargetESAPI); !ready {
log.Infof("%s at %s is %s, delaying migration ", status.Name, c.TargetEs, status.Status)
<-timer.C
continue
}
}
timer.Stop()
break
}
if len(c.SourceEs) > 0 {
// get all indexes from source
indexNames, indexCount, sourceIndexMappings, err := migrator.SourceESAPI.GetIndexMappings(c.CopyAllIndexes, c.SourceIndexNames)
if err != nil {
log.Error(err)
return
}
// remembers each index's original refresh_interval so it can be
// restored by recoveryIndexSettings when migration finishes
sourceIndexRefreshSettings := map[string]interface{}{}
log.Debugf("indexCount: %d",indexCount)
if indexCount > 0 {
//override indexnames to be copy
c.SourceIndexNames = indexNames
// copy index settings if user asked
if c.CopyIndexSettings || c.ShardsCount > 0 {
log.Info("start settings/mappings migration..")
//get source index settings
var sourceIndexSettings *Indexes
sourceIndexSettings, err := migrator.SourceESAPI.GetIndexSettings(c.SourceIndexNames)
log.Debug("source index settings:", sourceIndexSettings)
if err != nil {
log.Error(err)
return
}
//get target index settings
targetIndexSettings, err := migrator.TargetESAPI.GetIndexSettings(c.TargetIndexName)
if err != nil {
//ignore target es settings error
log.Debug(err)
}
log.Debug("target IndexSettings", targetIndexSettings)
//if there is only one index and we specify the dest indexname
if (c.SourceIndexNames != c.TargetIndexName && (len(c.TargetIndexName) > 0) && indexCount == 1 ) {
log.Debugf("only one index,so we can rewrite indexname, src:%v, dest:%v ,indexCount:%d",c.SourceIndexNames,c.TargetIndexName,indexCount)
(*sourceIndexSettings)[c.TargetIndexName] = (*sourceIndexSettings)[c.SourceIndexNames]
delete(*sourceIndexSettings, c.SourceIndexNames)
log.Debug(sourceIndexSettings)
}
// dealing with indices settings
for name, idx := range *sourceIndexSettings {
log.Debug("dealing with index,name:", name, ",settings:", idx)
tempIndexSettings := getEmptyIndexSettings()
targetIndexExist := false
//if target index settings is exist and we don't copy settings, we use target settings
if targetIndexSettings != nil {
//if target es have this index and we dont copy index settings
if val, ok := (*targetIndexSettings)[name]; ok {
targetIndexExist = true
tempIndexSettings = val.(map[string]interface{})
}
if c.RecreateIndex {
migrator.TargetESAPI.DeleteIndex(name)
targetIndexExist = false
}
}
//copy index settings
if c.CopyIndexSettings {
tempIndexSettings = ((*sourceIndexSettings)[name]).(map[string]interface{})
}
//check map elements
if _, ok := tempIndexSettings["settings"]; !ok {
tempIndexSettings["settings"] = map[string]interface{}{}
}
if _, ok := tempIndexSettings["settings"].(map[string]interface{})["index"]; !ok {
tempIndexSettings["settings"].(map[string]interface{})["index"] = map[string]interface{}{}
}
sourceIndexRefreshSettings[name] = ((*sourceIndexSettings)[name].(map[string]interface{}))["settings"].(map[string]interface{})["index"].(map[string]interface{})["refresh_interval"]
//set refresh_interval to -1 and replicas to 0 during the bulk load for speed
tempIndexSettings["settings"].(map[string]interface{})["index"].(map[string]interface{})["refresh_interval"] = -1
tempIndexSettings["settings"].(map[string]interface{})["index"].(map[string]interface{})["number_of_replicas"] = 0
//clean up settings
delete(tempIndexSettings["settings"].(map[string]interface{})["index"].(map[string]interface{}), "number_of_shards")
//copy indexsettings and mappings
if targetIndexExist {
log.Debug("update index with settings,", name, tempIndexSettings)
//override shard settings
if c.ShardsCount > 0 {
tempIndexSettings["settings"].(map[string]interface{})["index"].(map[string]interface{})["number_of_shards"] = c.ShardsCount
}
err := migrator.TargetESAPI.UpdateIndexSettings(name, tempIndexSettings)
if err != nil {
log.Error(err)
}
} else {
//override shard settings
if c.ShardsCount > 0 {
tempIndexSettings["settings"].(map[string]interface{})["index"].(map[string]interface{})["number_of_shards"] = c.ShardsCount
}
log.Debug("create index with settings,", name, tempIndexSettings)
err := migrator.TargetESAPI.CreateIndex(name, tempIndexSettings)
if err != nil {
log.Error(err)
}
}
}
if c.CopyIndexMappings {
//if there is only one index and we specify the dest indexname
if (c.SourceIndexNames != c.TargetIndexName && (len(c.TargetIndexName) > 0) && indexCount == 1 ) {
log.Debugf("only one index,so we can rewrite indexname, src:%v, dest:%v ,indexCount:%d",c.SourceIndexNames,c.TargetIndexName,indexCount)
(*sourceIndexMappings)[c.TargetIndexName] = (*sourceIndexMappings)[c.SourceIndexNames]
delete(*sourceIndexMappings, c.SourceIndexNames)
log.Debug(sourceIndexMappings)
}
for name, mapping := range *sourceIndexMappings {
err := migrator.TargetESAPI.UpdateIndexMapping(name, mapping.(map[string]interface{})["mappings"].(map[string]interface{}))
if err != nil {
log.Error(err)
}
}
}
log.Info("settings/mappings migration finished.")
}
} else {
log.Error("index not exists,", c.SourceIndexNames)
return
}
// restore the original refresh_interval on the target once main returns
defer migrator.recoveryIndexSettings(sourceIndexRefreshSettings)
} else if len(c.DumpInputFile) > 0 {
//check shard settings
//TODO support shard config
}
}
log.Info("start data migration..")
//start es bulk thread
if len(c.TargetEs) > 0 {
log.Debug("start es bulk workers")
outputBar.Prefix("Bulk")
var docCount int
wg.Add(c.Workers)
for i := 0; i < c.Workers; i++ {
go migrator.NewBulkWorker(&docCount, outputBar, &wg)
}
} else if len(c.DumpOutFile) > 0 {
// start file write
outputBar.Prefix("Write")
wg.Add(1)
go migrator.NewFileDumpWorker(outputBar, &wg)
}
wg.Wait()
outputBar.Finish()
// close pool
pool.Stop()
log.Info("data migration finished.")
}
// recoveryIndexSettings restores the original refresh_interval on each
// migrated target index and, when configured, forces a refresh.
func (c *Migrator) recoveryIndexSettings(sourceIndexRefreshSettings map[string]interface{}) {
	//update replica and refresh_interval
	for name, interval := range sourceIndexRefreshSettings {
		settings := getEmptyIndexSettings()
		index := settings["settings"].(map[string]interface{})["index"].(map[string]interface{})
		index["refresh_interval"] = interval
		c.TargetESAPI.UpdateIndexSettings(name, settings)
		if c.Config.Refresh {
			c.TargetESAPI.Refresh(name)
		}
	}
}
// ClusterVersion fetches the root endpoint of host and decodes the cluster
// version document. It returns a non-empty error slice on any failure.
func (c *Migrator) ClusterVersion(host string, auth *Auth, proxy string) (*ClusterVersion, []error) {
	_, body, errs := Get(host, auth, proxy)
	if errs != nil {
		log.Error(errs)
		return nil, errs
	}
	log.Debug(body)
	version := &ClusterVersion{}
	err := json.Unmarshal([]byte(body), version)
	if err != nil {
		// bug fix: previously returned the nil errs slice here, so callers
		// saw a nil version together with a nil error
		log.Error(body, err)
		return nil, []error{err}
	}
	return version, nil
}
// ClusterReady reports whether the cluster behind api is healthy enough to
// proceed: green always qualifies, yellow only when WaitForGreen is off,
// red (or any other status) never does.
func (c *Migrator) ClusterReady(api ESAPI) (*ClusterHealth, bool) {
	health := api.ClusterHealth()
	switch health.Status {
	case "green":
		return health, true
	case "yellow":
		return health, !c.Config.WaitForGreen
	default:
		// "red" and any unknown status
		return health, false
	}
}
/*
Copyright 2016 Medcl (m AT medcl.net)
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
import
(
"encoding/json"
log "github.com/cihub/seelog"
"testing"
)
// TestParse verifies that a scan/scroll search response decodes into the
// Scroll struct and yields its scroll id.
func TestParse(test *testing.T) {
	setInitLogging("debug")
	text := `{ "_scroll_id": "c2NhbjswOzE7dG90YWxfaGl0czoxODY1MjY5Ow==", "took": 1, "timed_out": false, "_shards": { "total": 1, "successful": 0, "failed": 1, "failures": [ { "shard": -1, "index": null } ] }, "hits": { "total": 1865269, "max_score": 0, "hits": [] } }`
	scroll := Scroll{}
	err := json.Unmarshal([]byte(text), &scroll)
	if err != nil {
		// bug fix: previously logged and returned, silently passing the test
		test.Fatal(err)
	}
	if scroll.ScrollId == "" {
		test.Fatal("scroll id was not decoded")
	}
	log.Info(scroll.ScrollId)
}
/*
Copyright 2016 Medcl (m AT medcl.net)
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
import (
"gopkg.in/cheggaaa/pb.v1"
"encoding/json"
log "github.com/cihub/seelog"
)
// Stream from source es instance. "done" is an indicator that the stream is
// over
// ProcessScrollResult pushes one page of scroll hits into the migrator's
// document channel, advancing the progress bar and logging shard failures.
// It blocks when DocChan is full.
func (s *Scroll) ProcessScrollResult(c *Migrator, bar *pb.ProgressBar) {
	//update progress bar
	bar.Add(len(s.Hits.Docs))
	// show any failures
	for _, failure := range s.Shards.Failures {
		reason, _ := json.Marshal(failure.Reason)
		// bug fix: the failure text was previously passed to Errorf as a
		// format string, so any '%' in it corrupted the log line
		log.Error(string(reason))
	}
	// write all the docs into a channel
	for _, docI := range s.Hits.Docs {
		c.DocChan <- docI.(map[string]interface{})
	}
}
// Next fetches the next scroll page and feeds it to the migrator.
// It returns true when the scroll is exhausted; the caller loops while the
// result is false.
func (s *Scroll) Next(c *Migrator, bar *pb.ProgressBar) (done bool) {
scroll,err:=c.SourceESAPI.NextScroll(c.Config.ScrollTime,s.ScrollId)
if err != nil {
// NOTE(review): returning false here makes the caller retry immediately
// and forever on a persistent error (e.g. an expired scroll id) — confirm
// whether an unbounded tight retry loop is intended.
log.Error(err)
return false
}
// an empty page signals that the scroll is finished
if scroll.Hits.Docs == nil || len(scroll.Hits.Docs) <= 0 {
log.Debug("scroll result is empty")
return true
}
scroll.ProcessScrollResult(c,bar)
//update scrollId
s.ScrollId=scroll.ScrollId
// named return "done" is still false here: more pages remain
return
}
/*
Copyright 2016 Medcl (m AT medcl.net)
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
log "github.com/cihub/seelog"
"regexp"
"strings"
)
// ESAPIV0 implements the migration API against pre-5.x Elasticsearch
// clusters, using the scan/scroll flavor of the search API.
type ESAPIV0 struct {
Host string //eg: http://localhost:9200
Auth *Auth //eg: user:pass
HttpProxy string //eg: http://proxyIp:proxyPort
}
// ClusterHealth queries /_cluster/health. Any transport or decode failure
// is reported as a health document with status "unreachable" rather than
// as an error.
func (s *ESAPIV0) ClusterHealth() *ClusterHealth {
	url := fmt.Sprintf("%s/_cluster/health", s.Host)
	_, body, errs := Get(url, s.Auth, s.HttpProxy)
	if errs != nil {
		return &ClusterHealth{Name: s.Host, Status: "unreachable"}
	}
	log.Debug(url)
	log.Debug(body)
	health := &ClusterHealth{}
	if err := json.Unmarshal([]byte(body), health); err != nil {
		log.Error(body)
		return &ClusterHealth{Name: s.Host, Status: "unreachable"}
	}
	return health
}
// Bulk posts the buffered bulk payload to /_bulk and clears the buffer on
// success; a nil or empty buffer is a no-op. On failure the buffer is left
// intact.
func (s *ESAPIV0) Bulk(data *bytes.Buffer) {
	if data == nil || data.Len() == 0 {
		return
	}
	// the bulk API requires a trailing newline
	data.WriteRune('\n')
	url := fmt.Sprintf("%s/_bulk", s.Host)
	body, err := Request("POST", url, s.Auth, data, s.HttpProxy)
	if err != nil {
		log.Error(err)
		return
	}
	log.Trace(url, body)
	data.Reset()
}
// GetIndexSettings fetches /_settings for the given (comma separated) index
// names and decodes them into an Indexes map.
func (s *ESAPIV0) GetIndexSettings(indexNames string) (*Indexes, error) {
	// get all settings
	allSettings := &Indexes{}
	url := fmt.Sprintf("%s/%s/_settings", s.Host, indexNames)
	resp, body, errs := Get(url, s.Auth, s.HttpProxy)
	if errs != nil {
		return nil, errs[0]
	}
	// drain and close so the underlying connection can be reused
	io.Copy(ioutil.Discard, resp.Body)
	defer resp.Body.Close()
	if resp.StatusCode != 200 {
		return nil, errors.New(body)
	}
	log.Debug(body)
	if err := json.Unmarshal([]byte(body), allSettings); err != nil {
		// bug fix: previously panicked here (with an unreachable return after
		// it), killing the whole migration on a malformed payload
		return nil, err
	}
	return allSettings, nil
}
// GetIndexMappings fetches /_mapping for indexNames and returns the resolved
// comma-separated index name list, the number of indexes found, and their
// mapping documents. "_all" and wildcard patterns are expanded against the
// index names present in the response. Pre-mapping-era responses are
// wrapped so every entry has a "mappings" key.
func (s *ESAPIV0) GetIndexMappings(copyAllIndexes bool, indexNames string) (string, int, *Indexes, error) {
url := fmt.Sprintf("%s/%s/_mapping", s.Host, indexNames)
resp, body, errs := Get(url, s.Auth,s.HttpProxy)
if errs != nil {
log.Error(errs)
return "", 0, nil, errs[0]
}
// drain and close so the connection can be reused
io.Copy(ioutil.Discard, resp.Body)
defer resp.Body.Close()
if resp.StatusCode != 200 {
return "", 0, nil, errors.New(body)
}
idxs := Indexes{}
er := json.Unmarshal([]byte(body), &idxs)
if er != nil {
log.Error(body)
return "", 0, nil, er
}
// remove indexes that start with . if user asked for it
// NOTE(review): copyAllIndexes is currently unused — the filtering below is
// commented out.
//if copyAllIndexes == false {
// for name := range idxs {
// switch name[0] {
// case '.':
// delete(idxs, name)
// case '_':
// delete(idxs, name)
//
//
// }
// }
// }
// if _all indexes limit the list of indexes to only these that we kept
// after looking at mappings
if indexNames == "_all" {
var newIndexes []string
for name := range idxs {
newIndexes = append(newIndexes, name)
}
indexNames = strings.Join(newIndexes, ",")
} else if strings.Contains(indexNames, "*") || strings.Contains(indexNames, "?") {
// NOTE(review): the pattern is compiled as a Go regexp, not an index glob,
// so "logs-*" does not mean what a glob user expects — confirm intent.
r, _ := regexp.Compile(indexNames)
//check index patterns
var newIndexes []string
for name := range idxs {
matched := r.MatchString(name)
if matched {
newIndexes = append(newIndexes, name)
}
}
indexNames = strings.Join(newIndexes, ",")
}
i := 0
// wrap in mappings if moving from super old es
for name, idx := range idxs {
i++
if _, ok := idx.(map[string]interface{})["mappings"]; !ok {
(idxs)[name] = map[string]interface{}{
"mappings": idx,
}
}
}
return indexNames, i, &idxs, nil
}
// getEmptyIndexSettings returns a skeleton settings document of the shape
// {"settings": {"index": {}}}.
func getEmptyIndexSettings() map[string]interface{} {
	return map[string]interface{}{
		"settings": map[string]interface{}{
			"index": map[string]interface{}{},
		},
	}
}
// cleanSettings strips read-only index metadata (creation_date, uuid,
// version, provided_name) that Elasticsearch rejects when settings are
// written back. The map is modified in place.
func cleanSettings(settings map[string]interface{}) {
	index := settings["settings"].(map[string]interface{})["index"].(map[string]interface{})
	for _, key := range []string{"creation_date", "uuid", "version", "provided_name"} {
		delete(index, key)
	}
}
// UpdateIndexSettings pushes settings onto an existing index. Static
// settings (analysis) require a closed index, so they are applied first
// inside a _close/_open cycle; the remaining dynamic settings are then PUT
// normally.
func (s *ESAPIV0) UpdateIndexSettings(name string, settings map[string]interface{}) error {
	log.Debug("update index: ", name, settings)
	cleanSettings(settings)
	url := fmt.Sprintf("%s/%s/_settings", s.Host, name)
	if _, ok := settings["settings"].(map[string]interface{})["index"]; ok {
		if set, ok := settings["settings"].(map[string]interface{})["index"].(map[string]interface{})["analysis"]; ok {
			// analysis is a static setting: close the index, apply it, reopen
			log.Debug("update static index settings: ", name)
			staticIndexSettings := getEmptyIndexSettings()
			staticIndexSettings["settings"].(map[string]interface{})["index"].(map[string]interface{})["analysis"] = set
			Post(fmt.Sprintf("%s/%s/_close", s.Host, name), s.Auth, "", s.HttpProxy)
			body := bytes.Buffer{}
			enc := json.NewEncoder(&body)
			enc.Encode(staticIndexSettings)
			bodyStr, err := Request("PUT", url, s.Auth, &body, s.HttpProxy)
			if err != nil {
				// bug fix: previously panicked here (with an unreachable return
				// after it); report the error to the caller instead
				log.Error(bodyStr, err)
				return err
			}
			delete(settings["settings"].(map[string]interface{})["index"].(map[string]interface{}), "analysis")
			Post(fmt.Sprintf("%s/%s/_open", s.Host, name), s.Auth, "", s.HttpProxy)
		}
	}
	log.Debug("update dynamic index settings: ", name)
	body := bytes.Buffer{}
	enc := json.NewEncoder(&body)
	enc.Encode(settings)
	_, err := Request("PUT", url, s.Auth, &body, s.HttpProxy)
	return err
}
// UpdateIndexMapping POSTs each type mapping in settings to the target
// index, one _mapping call per document type. It returns the first error
// encountered.
func (s *ESAPIV0) UpdateIndexMapping(indexName string, settings map[string]interface{}) error {
	log.Debug("start update mapping: ", indexName, settings)
	for name, mapping := range settings {
		log.Debug("start update mapping: ", indexName, name, mapping)
		url := fmt.Sprintf("%s/%s/%s/_mapping", s.Host, indexName, name)
		body := bytes.Buffer{}
		enc := json.NewEncoder(&body)
		enc.Encode(mapping)
		res, err := Request("POST", url, s.Auth, &body, s.HttpProxy)
		if err != nil {
			// bug fix: previously panicked; the signature already promises an
			// error, so surface it to the caller instead
			log.Error(url)
			log.Error(body.String())
			log.Error(err, res)
			return err
		}
	}
	return nil
}
// DeleteIndex removes the named index from the cluster and returns any
// request error.
func (s *ESAPIV0) DeleteIndex(name string) (err error) {
	log.Debug("start delete index: ", name)
	url := fmt.Sprintf("%s/%s", s.Host, name)
	// bug fix: the request error was silently dropped while the function
	// unconditionally returned nil; propagate it instead
	_, err = Request("DELETE", url, s.Auth, nil, s.HttpProxy)
	log.Debug("delete index: ", name)
	return err
}
// CreateIndex creates an index with the given settings (cleaned of
// read-only metadata first) via a PUT to the index endpoint.
func (s *ESAPIV0) CreateIndex(name string, settings map[string]interface{}) (err error) {
	cleanSettings(settings)
	var payload bytes.Buffer
	json.NewEncoder(&payload).Encode(settings)
	log.Debug("start create index: ", name, settings)
	url := fmt.Sprintf("%s/%s", s.Host, name)
	resp, err := Request("PUT", url, s.Auth, &payload, s.HttpProxy)
	log.Debugf("response: %s", resp)
	return err
}
// Refresh issues a _refresh on the named index so recently written
// documents become searchable; errors from the call are ignored.
func (s *ESAPIV0) Refresh(name string) (err error) {
	log.Debug("refresh index: ", name)
	Post(fmt.Sprintf("%s/%s/_refresh", s.Host, name), s.Auth, "", s.HttpProxy)
	return nil
}
// NewScroll opens a scan/scroll search on pre-5.x Elasticsearch. slicedId
// and maxSlicedCount are accepted for interface compatibility but ignored:
// sliced scroll does not exist before 5.x. fields, when set, must be a
// comma separated _source filter list.
func (s *ESAPIV0) NewScroll(indexNames string, scrollTime string, docBufferCount int, query string, slicedId, maxSlicedCount int, fields string) (scroll *Scroll, err error) {
	// curl -XGET 'http://es-0.9:9200/_search?search_type=scan&scroll=10m&size=50'
	url := fmt.Sprintf("%s/%s/_search?search_type=scan&scroll=%s&size=%d", s.Host, indexNames, scrollTime, docBufferCount)
	jsonBody := ""
	if len(query) > 0 || len(fields) > 0 {
		queryBody := map[string]interface{}{}
		if len(fields) > 0 {
			if !strings.Contains(fields, ",") {
				// bug fix: previously returned (nil, nil) here, which crashed the
				// caller on the nil scroll; return a real error (message typo also
				// fixed)
				log.Error("The fields should be separated by ,")
				return nil, errors.New("the fields should be separated by ,")
			}
			queryBody["_source"] = strings.Split(fields, ",")
		}
		if len(query) > 0 {
			queryBody["query"] = map[string]interface{}{}
			queryBody["query"].(map[string]interface{})["query_string"] = map[string]interface{}{}
			queryBody["query"].(map[string]interface{})["query_string"].(map[string]interface{})["query"] = query
		}
		// bug fix: the body was previously marshaled only when query was set,
		// so a fields-only _source filter was silently dropped
		jsonArray, marshalErr := json.Marshal(queryBody)
		if marshalErr != nil {
			log.Error(marshalErr)
		} else {
			jsonBody = string(jsonArray)
		}
	}
	resp, body, errs := Post(url, s.Auth, jsonBody, s.HttpProxy)
	// bug fix: this previously tested the always-nil named return err, so
	// transport failures were silently ignored
	if errs != nil {
		log.Error(errs)
		return nil, errs[0]
	}
	io.Copy(ioutil.Discard, resp.Body)
	defer resp.Body.Close()
	log.Trace("new scroll,", url, body)
	if resp.StatusCode != 200 {
		return nil, errors.New(body)
	}
	scroll = &Scroll{}
	err = json.Unmarshal([]byte(body), scroll)
	if err != nil {
		log.Error(err)
		return nil, err
	}
	return scroll, err
}
// NextScroll fetches the next page for scrollId, keeping the scroll context
// alive for scrollTime.
func (s *ESAPIV0) NextScroll(scrollTime string, scrollId string) (*Scroll, error) {
	// curl -XGET 'http://es-0.9:9200/_search/scroll?scroll=5m'
	url := fmt.Sprintf("%s/_search/scroll?scroll=%s&scroll_id=%s", s.Host, scrollTime, scrollId)
	resp, body, errs := Get(url, s.Auth, s.HttpProxy)
	if errs != nil {
		log.Error(errs)
		return nil, errs[0]
	}
	// bug fix: close on every path; the body previously leaked when the
	// status code was not 200
	defer resp.Body.Close()
	io.Copy(ioutil.Discard, resp.Body)
	if resp.StatusCode != 200 {
		return nil, errors.New(body)
	}
	log.Trace("next scroll,", url, body)
	// decode elasticsearch scroll response
	scroll := &Scroll{}
	err := json.Unmarshal([]byte(body), &scroll)
	if err != nil {
		log.Error(body)
		log.Error(err)
		return nil, err
	}
	return scroll, nil
}
/*
Copyright 2016 Medcl (m AT medcl.net)
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package main
import (
"bytes"
log "github.com/cihub/seelog"
"encoding/json"
"fmt"
"errors"
"io"
"io/ioutil"
"strings"
)
// ESAPIV5 adapts the migration API to Elasticsearch 5.x. It embeds ESAPIV0
// and overrides the scroll calls, whose endpoints and semantics changed in
// 5.x (plain scroll with optional slicing instead of scan/scroll).
type ESAPIV5 struct{
ESAPIV0
}
// The methods below delegate verbatim to the embedded ESAPIV0.
// NOTE(review): Go embedding already promotes these methods, so the explicit
// wrappers are redundant and could be removed without behavior change.
func (s *ESAPIV5) ClusterHealth() *ClusterHealth {
return s.ESAPIV0.ClusterHealth()
}
func (s *ESAPIV5) Bulk(data *bytes.Buffer){
s.ESAPIV0.Bulk(data)
}
func (s *ESAPIV5) GetIndexSettings(indexNames string) (*Indexes,error){
return s.ESAPIV0.GetIndexSettings(indexNames)
}
func (s *ESAPIV5) GetIndexMappings(copyAllIndexes bool,indexNames string)(string,int,*Indexes,error){
return s.ESAPIV0.GetIndexMappings(copyAllIndexes,indexNames)
}
func (s *ESAPIV5) UpdateIndexSettings(indexName string,settings map[string]interface{})(error){
return s.ESAPIV0.UpdateIndexSettings(indexName,settings)
}
func (s *ESAPIV5) DeleteIndex(name string) (err error) {
return s.ESAPIV0.DeleteIndex(name)
}
func (s *ESAPIV5) CreateIndex(name string,settings map[string]interface{}) (err error) {
return s.ESAPIV0.CreateIndex(name,settings)
}
func (s *ESAPIV5) UpdateIndexMapping(indexName string,settings map[string]interface{}) error {
return s.ESAPIV0.UpdateIndexMapping(indexName,settings)
}
func (s *ESAPIV5) Refresh(name string) (err error) {
return s.ESAPIV0.Refresh(name)
}
// NewScroll opens a scroll search on Elasticsearch 5.x. When maxSlicedCount
// is greater than 1, a sliced-scroll request is issued for slice slicedId.
// fields, when set, must be a comma separated _source filter list.
func (s *ESAPIV5) NewScroll(indexNames string, scrollTime string, docBufferCount int, query string, slicedId, maxSlicedCount int, fields string) (scroll *Scroll, err error) {
	url := fmt.Sprintf("%s/%s/_search?scroll=%s&size=%d", s.Host, indexNames, scrollTime, docBufferCount)
	jsonBody := ""
	if len(query) > 0 || maxSlicedCount > 0 || len(fields) > 0 {
		queryBody := map[string]interface{}{}
		if len(fields) > 0 {
			if !strings.Contains(fields, ",") {
				// NOTE(review): a single field without a comma is rejected here —
				// confirm whether that restriction is intended.
				// bug fix: the error carried an empty message (errors.New(""));
				// the message spelling is also fixed
				log.Error("The fields should be separated by ,")
				return nil, errors.New("the fields should be separated by ,")
			}
			queryBody["_source"] = strings.Split(fields, ",")
		}
		if len(query) > 0 {
			queryBody["query"] = map[string]interface{}{}
			queryBody["query"].(map[string]interface{})["query_string"] = map[string]interface{}{}
			queryBody["query"].(map[string]interface{})["query_string"].(map[string]interface{})["query"] = query
		}
		if maxSlicedCount > 1 {
			log.Tracef("sliced scroll, %d of %d", slicedId, maxSlicedCount)
			queryBody["slice"] = map[string]interface{}{
				"id":  slicedId,
				"max": maxSlicedCount,
			}
		}
		jsonArray, marshalErr := json.Marshal(queryBody)
		if marshalErr != nil {
			log.Error(marshalErr)
		} else {
			jsonBody = string(jsonArray)
		}
	}
	resp, body, errs := Post(url, s.Auth, jsonBody, s.HttpProxy)
	if errs != nil {
		log.Error(errs)
		return nil, errs[0]
	}
	io.Copy(ioutil.Discard, resp.Body)
	defer resp.Body.Close()
	if resp.StatusCode != 200 {
		return nil, errors.New(body)
	}
	log.Trace("new scroll,", body)
	// (removed a dead "if err != nil" check: err could never be non-nil here
	// because the marshal error was shadowed in the block above)
	scroll = &Scroll{}
	err = json.Unmarshal([]byte(body), scroll)
	if err != nil {
		log.Error(err)
		return nil, err
	}
	return scroll, err
}
// NextScroll retrieves the next batch of hits for an open 5.x scroll
// context, keeping it alive for scrollTime.
func (s *ESAPIV5) NextScroll(scrollTime string, scrollId string) (*Scroll, error) {
	id := bytes.NewBufferString(scrollId)
	url := fmt.Sprintf("%s/_search/scroll?scroll=%s&scroll_id=%s", s.Host, scrollTime, id)
	resp, body, errs := Get(url, s.Auth, s.HttpProxy)
	if errs != nil {
		log.Error(errs)
		return nil, errs[0]
	}
	// drain and close so the connection can be reused
	io.Copy(ioutil.Discard, resp.Body)
	defer resp.Body.Close()
	if resp.StatusCode != 200 {
		return nil, errors.New(body)
	}
	// decode elasticsearch scroll response
	scroll := &Scroll{}
	if err := json.Unmarshal([]byte(body), &scroll); err != nil {
		log.Error(body)
		log.Error(err)
		return nil, err
	}
	return scroll, nil
}
This diff is collapsed.
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"#hide\n",
"from Iterative_masking.core import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Iterative_masking\n",
"> Supporting repository for: \"Generative power of a protein language model trained on multiple sequence alignments\" (preprint: https://doi.org/10.1101/2022.04.14.488405). We use MSA Transformer (https://doi.org/10.1101/2021.02.12.430858) to generate synthetic protein sequences by masking iteratively the same MSA."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Getting started\n",
"\n",
"Clone this repository on your local machine by running:\n",
"\n",
"```bash\n",
"git clone git@github.com:Bitbol-Lab/Iterative_masking.git\n",
"```\n",
"and move inside the root folder.\n",
    "One can then use the functions directly from the cloned repository (in the folder `Iterative_masking`) or install it as an editable package by running:\n",
"\n",
"```bash\n",
"pip install -e .\n",
"```\n",
"\n",
"We recommend creating and activating a dedicated ``conda`` or ``virtualenv`` Python virtual environment.\n",
"\n",
"## Requirements\n",
"In order to use the functions, the following python packages are required:\n",
"\n",
"- numpy\n",
"- scipy\n",
"- numba\n",
"- fastcore\n",
"- biopython\n",
"- esm==0.4.0\n",
"- pytorch\n",
"\n",
"It is also required to use a GPU (with cuda)."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## How to use"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"`IM_MSA_Transformer`: Class with different functions used to generate new MSAs with the iterative masking procedure\n",
"\n",
"`gen_MSAs`: example function (with parser) that can be used to generate and save new sequences directly from the terminal.\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# example on how to use `gen_MSAs` to replicate the results of the paper\n",
"\n",
"gen_MSAs(filepath=\"examples\",\n",
" filename=[\"PF00072.fasta\"],\n",
" new_dir=\"results\",\n",
" pdf=False,\n",
" T=1,\n",
" sample_all=False,\n",
" Iters=200,\n",
" pmask=0.1,\n",
" num=[600],\n",
" depth=1e10, #to do entire MSA\n",
" generate=False,\n",
" print_all=False,\n",
" range_vals=False,\n",
" phylo_w=False)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
# Example: replicate the paper's results by generating new MSAs with the
# iterative masking procedure.
from Iterative_masking.core import *

# gen_MSAs is brought into scope by the star import above. The original
# call, `Iterative_masking.core.gen_MSAs(...)`, would raise NameError:
# `from X import *` does not bind the module name `Iterative_masking`.
gen_MSAs(filepath="examples",
         filename=["PF00072.fasta"],
         new_dir="results",
         pdf=False,
         T=1,
         sample_all=False,
         Iters=200,
         pmask=0.1,
         num=[600],
         depth=1e10,  # large depth so the entire MSA is used
         generate=False,
         print_all=False,
         range_vals=False,
         phylo_w=False)
File added
Creation of the directory /Iterative_masking-master/results_new failed
Tokenize
MSA Transformer model imported
[10] 1 1
Number of sequences in /Iterative_masking-master/examples/PF00072.fasta: 73062
MSA Imported
We are using batch MSAs of 10 sequences
MSA converted into tokens tensor of size and type:
torch.Size([1, 10, 113]) torch.int64
Generate Class
MSA Transformer model imported
[10] 1 1
Number of sequences in /Iterative_masking-master/examples/PF00072.fasta: 73062
MSA Imported
We are using batch MSAs of 10 sequences
MSA converted into tokens tensor of size and type:
torch.Size([1, 10, 113]) torch.int64
Compute results from Class
Generating MSA with same size as the original one
Successfully created the directory /Iterative_masking-master/results_new/Generated_iter-20_pmask-0.1_seqs-100_(only-masked-sampled)
---
### 1
int([x]) -> integer
int(x, base=10) -> integer
Convert a number or string to an integer, or return 0 if no arguments
are given. If x is a number, return x.__int__(). For floating point
numbers, this truncates towards zero.
If x is not a number or if base is given, then x must be a string,
bytes, or bytearray instance representing an integer literal in the
given base. The literal can be preceded by '+' or '-' and be surrounded
by whitespace. The base defaults to 10. Valid bases are 0 and 2-36.
Base 0 means to interpret the base from the string as an integer literal.
>>> int('0b100', base=0)
4
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment